{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "private_outputs": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Run both Whisper and Voicevox in a single Colab Session\n",
        "\n",
        "To enable GPU in this notebook, select Runtime -> Change runtime type in the Menu bar. Under Hardware Accelerator, select GPU.\n",
        "\n",
        "Then, scroll to the Configuration [cell](#scrollTo=8WIVDY-V-kVw&line=1&uniqifier=1) and update it with your ngrok authentication token.\n",
        "\n",
        "To run, select Runtime -> Run all. Go to this [cell](#scrollTo=5M_2NlAXB89F&line=1&uniqifier=1) and read the instructions on how to update your `.env` file."
      ],
      "metadata": {
        "id": "4cFX5ElsYbPB"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Check GPU"
      ],
      "metadata": {
        "id": "e_2rOjTRCuGd"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!nvidia-smi"
      ],
      "metadata": {
        "id": "2uEbLKtCCtcq"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Install Deps"
      ],
      "metadata": {
        "id": "9_RqReHfCwPU"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install flask -q\n",
        "!pip install pyngrok -q\n",
        "!pip install git+https://github.com/openai/whisper.git -q\n",
        "!pip install requests -q\n",
        "!pip install flask-cors -q"
      ],
      "metadata": {
        "id": "vKClg4iV3tnN"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "njuoTw33hJyh"
      },
      "outputs": [],
      "source": [
        "# VoiceVox Engine\n",
        "ENGINE_VER = '0.14.2'\n",
        "ZIP_FILENAME = f'voicevox_engine-linux-nvidia-{ENGINE_VER}.7z.001'\n",
        "DOWNLOAD_LINK = f'https://github.com/VOICEVOX/voicevox_engine/releases/download/{ENGINE_VER}/{ZIP_FILENAME}'\n",
        "\n",
        "!wget $DOWNLOAD_LINK\n",
        "!7z x $ZIP_FILENAME -y\n",
        "!rm $ZIP_FILENAME\n",
        "\n",
        "!git clone https://github.com/VOICEVOX/voicevox_engine -q\n",
        "\n",
        "# Install deps, ignoring python version\n",
        "!pip install -r <(sed 's/;.*//' voicevox_engine/requirements.txt) -q\n",
        "!pip install numpy==1.22 -q"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Configuration\n",
        "\n",
        "Please set the NGROK auth token to access the tunnel."
      ],
      "metadata": {
        "id": "kY-BMho4soa9"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "NGROK_AUTH_TOKEN = '' #@param {type:'string'}\n",
        "TRANSLATE_FILENAME = 'translate.wav' #@param {type:'string'}\n",
        "TRANSCRIBE_FILENAME = 'transcribe.wav' #@param {type:'string'}\n",
        "WHISPER_MODEL = 'small' #@param ['tiny', 'base', 'small', 'medium', 'large']\n",
        "VOICEVOX_URL = 'http://localhost:50021' #@param {type:'string'}\n",
        "CHUNK_SIZE = 4096 #@param {type:'integer'}"
      ],
      "metadata": {
        "id": "8WIVDY-V-kVw"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Main Code"
      ],
      "metadata": {
        "id": "U4I-fc5kCKyS"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from flask import Flask, request, Response\n",
        "import json\n",
        "import whisper\n",
        "import subprocess\n",
        "import requests\n",
        "from pyngrok import ngrok\n",
        "from flask_cors import CORS\n",
        "import os\n",
        "from urllib.parse import urlencode\n",
        "\n",
        "model = whisper.load_model(WHISPER_MODEL)"
      ],
      "metadata": {
        "id": "xJpGDX7t6LA_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Update .env file Instructions\n",
        "\n",
        "After the cell below has started running, copy the public url provided by ngrok and update both WHISPER_BASE_URL and VOICEVOX_BASE_URL in your `.env` file. Below is an example output that you will see. \n",
        "\n",
        "```\n",
        "NgrokTunnel: \"http://f9a8-34-73-238-198.ngrok.io\" -> \"http://localhost:5000\"\n",
        " * Serving Flask app '__main__'\n",
        " * Debug mode: off\n",
        "INFO:werkzeug:WARNING: This is a development server. Do not use it in a production deployment. Use a production WSGI server instead.\n",
        " * Running on http://127.0.0.1:5000\n",
        "INFO:werkzeug:Press CTRL+C to quit\n",
        "```\n",
        "\n",
        "DO NOT use this url, use the URL provided by the actual output from running the cell below. In this example, you will update your WHISPER_BASE_URL and VOICEVOX_BASE_URL variable with:\n",
        "\n",
        "```\n",
        "WHISPER_BASE_URL=http://f9a8-34-73-238-198.ngrok.io\n",
        "VOICEVOX_BASE_URL=http://f9a8-34-73-238-198.ngrok.io\n",
        "```\n",
        "\n",
        "This url will change every time you rerun this cell, so remember to update your `.env` file when that happens."
      ],
      "metadata": {
        "id": "GZiFjH5QZFrd"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "app = Flask(__name__)\n",
        "ngrok.set_auth_token(NGROK_AUTH_TOKEN)\n",
        "CORS(app)\n",
        "\n",
        "@app.route('/', methods=['GET'])\n",
        "def test():\n",
        "  response = {'status':'OK','message':'Test'}\n",
        "  return Response(json.dumps({'status':'OK','message':'Test'}), mimetype='application/json')\n",
        "\n",
        "@app.route('/speakers', methods=['GET'])\n",
        "def speakers():\n",
        "  res = requests.get(f\"{VOICEVOX_URL}/speakers\")\n",
        "\n",
        "  return res.json()\n",
        "\n",
        "# Raw\n",
        "@app.route('/audio_query', methods=['POST'])\n",
        "def audio_query():\n",
        "  try:\n",
        "    params_encoded = urlencode(request.args)\n",
        "    r = requests.post(f'{VOICEVOX_URL}/audio_query?{params_encoded}')\n",
        "    return r.json()\n",
        "\n",
        "  except Exception:\n",
        "    return Response(json.dumps({ 'message': 'Failed to request audio_query', 'json': r.json(), 'status': 'Server Error' }), mimetype='application/json', status=500)\n",
        "\n",
        "@app.route('/synthesis', methods=['POST'])\n",
        "def synthesis():\n",
        "  json = request.get_json()\n",
        "  params_encoded = urlencode(request.args)\n",
        "  r = requests.post(f'{VOICEVOX_URL}/synthesis?{params_encoded}', json=json)\n",
        "\n",
        "  return Response(r.content, mimetype='audio/wav')\n",
        "\n",
        "# All in one\n",
        "@app.route('/tts', methods=['POST'])\n",
        "def tts():\n",
        "  text = request.args.get('text')\n",
        "  speaker = int(request.args.get('speaker') or '5')\n",
        "\n",
        "  if (text is None):\n",
        "    return json.dumps({ 'message': 'No text', 'status': 'BAD_REQUEST' }), 400\n",
        "\n",
        "  speed_scale = float(request.args.get('speed_scale') or '1.7')\n",
        "  volume_scale = float(request.args.get('volume_scale') or '4.0')\n",
        "  intonation_scale = float(request.args.get('intonation_scale') or '1.5')\n",
        "  pre_phoneme_length = float(request.args.get('pre_phoneme_length') or '1.0')\n",
        "  post_phoneme_length = float(request.args.get('post_phoneme_length') or '1.0')\n",
        "\n",
        "  params_encoded = urlencode({'text': text, 'speaker': speaker})\n",
        "  r = requests.post(f'{VOICEVOX_URL}/audio_query?{params_encoded}')\n",
        "\n",
        "  if r.status_code == 404:\n",
        "    return Response(json.dumps({ 'message': 'Failed to request audio_query', 'json': r.json(), 'status': 'Server Error' }), mimetype='application/json', status=500)\n",
        "\n",
        "  query = r.json()\n",
        "  query['speedScale'] = speed_scale\n",
        "  query['volumeScale'] = volume_scale\n",
        "  query['intonationScale'] = intonation_scale\n",
        "  query['prePhonemeLength'] = pre_phoneme_length\n",
        "  query['postPhonemeLength'] = post_phoneme_length\n",
        "\n",
        "  params_encoded = urlencode({'speaker': speaker})\n",
        "  r = requests.post(f'{VOICEVOX_URL}/synthesis?{params_encoded}', json=query)\n",
        "\n",
        "  return Response(r.content, mimetype='audio/wav')\n",
        "\n",
        "# Whisper speech-to-text endpoints\n",
        "@app.route('/asr', methods=['POST'])\n",
        "def asr():\n",
        "  task = request.args.get('task') or 'transcribe'\n",
        "  language = request.args.get('language') or 'ja'\n",
        "\n",
        "  if task == 'transcribe':\n",
        "    if (request.content_type.startswith('multipart/form-data')):\n",
        "      audio_data = request.files['audio_file']\n",
        "\n",
        "      if (audio_data is None):\n",
        "        return Response(json.dumps({ 'message': '\"file\" is missing on form data' }), mimetype='application/json', status=422)\n",
        "\n",
        "      audio_data.save(TRANSCRIBE_FILENAME)\n",
        "\n",
        "    else:\n",
        "      with open(TRANSCRIBE_FILENAME, \"bw\") as f:\n",
        "        while True:\n",
        "          chunk = request.stream.read(CHUNK_SIZE)\n",
        "\n",
        "          if len(chunk) == 0:\n",
        "              break\n",
        "\n",
        "          f.write(chunk)\n",
        "\n",
        "        f.close()\n",
        "\n",
        "    result = model.transcribe(TRANSCRIBE_FILENAME)\n",
        "    return Response(json.dumps(result), mimetype='application/json')\n",
        "\n",
        "  elif task == 'translate':\n",
        "    if (request.content_type.startswith('multipart/form-data')):\n",
        "      audio_data = request.files['audio_file']\n",
        "\n",
        "      if (audio_data is None):\n",
        "        return Response(json.dumps({ 'message': '\"file\" is missing on form data' }), mimetype='application/json', status=422)\n",
        "\n",
        "      audio_data.save(TRANSLATE_FILENAME)\n",
        "      \n",
        "    else:\n",
        "      with open(TRANSLATE_FILENAME, 'bw') as f:\n",
        "        while True:\n",
        "          chunk = request.stream.read(CHUNK_SIZE)\n",
        "\n",
        "          if len(chunk) == 0:\n",
        "              break\n",
        "\n",
        "          f.write(chunk)\n",
        "\n",
        "        f.close()\n",
        "\n",
        "    result = model.transcribe(TRANSLATE_FILENAME, language=language, task='translate')\n",
        "    return Response(json.dumps(result), mimetype='application/json')\n",
        "  \n",
        "  else:\n",
        "    return Response(json.dumps({ 'message': 'Unknown task', 'status': 'Bad Request' }), mimetype='application/json', status=400)\n",
        "\n",
        "def main():\n",
        "  # Start voicevox\n",
        "  sub = subprocess.Popen(\n",
        "    \"python /content/voicevox_engine/run.py --voicevox_dir='linux-nvidia' --use_gpu --allow_origin * --cors_policy_mode all\",\n",
        "    shell=True,\n",
        "    stdout=subprocess.PIPE\n",
        "  )\n",
        "\n",
        "  # Open tunnel\n",
        "  http_tunnel = ngrok.connect(5000)\n",
        "  print(http_tunnel)\n",
        "\n",
        "  # Run app\n",
        "  app.run()\n",
        "\n",
        "if __name__ == '__main__':\n",
        "    try:\n",
        "        main()\n",
        "    except KeyboardInterrupt:\n",
        "        print('ded')"
      ],
      "metadata": {
        "id": "5M_2NlAXB89F"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Leave the above cell running and the tab open\n",
        "\n",
        "This is to ensure the runtime does not disconnect and shut down the server. \n",
        "\n",
        "When you're done remember to disconnect the runtime."
      ],
      "metadata": {
        "id": "qDxifmXfZmGV"
      }
    }
  ]
}